/*
 * Copyright 2000-2017 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
 */
package qunar.tc.decompiler.modules.decompiler.stats;

import qunar.tc.decompiler.code.SwitchInstruction;
import qunar.tc.decompiler.code.cfg.BasicBlock;
import qunar.tc.decompiler.main.DecompilerContext;
import qunar.tc.decompiler.main.collectors.BytecodeMappingTracer;
import qunar.tc.decompiler.main.collectors.CounterContainer;
import qunar.tc.decompiler.modules.decompiler.DecHelper;
import qunar.tc.decompiler.modules.decompiler.ExprProcessor;
import qunar.tc.decompiler.modules.decompiler.StatEdge;
import qunar.tc.decompiler.modules.decompiler.SwitchHelper;
import qunar.tc.decompiler.modules.decompiler.exps.ConstExprent;
import qunar.tc.decompiler.modules.decompiler.exps.Exprent;
import qunar.tc.decompiler.modules.decompiler.exps.FieldExprent;
import qunar.tc.decompiler.modules.decompiler.exps.SwitchExprent;
import qunar.tc.decompiler.struct.gen.VarType;
import qunar.tc.decompiler.util.TextBuffer;

import java.util.*;

public class SwitchStatement extends Statement {

    // *****************************************************************************
    // private fields
    // *****************************************************************************

    private List<Statement> caseStatements = new ArrayList<>();

    private List<List<StatEdge>> caseEdges = new ArrayList<>();

    private List<List<Exprent>> caseValues = new ArrayList<>();

    private StatEdge default_edge;

    private final List<Exprent> headexprent = new ArrayList<>(1);

    // *****************************************************************************
    // constructors
    // *****************************************************************************

    private SwitchStatement() {
        type = TYPE_SWITCH;

        headexprent.add(null);
    }

    private SwitchStatement(Statement head, Statement poststat) {

        this();

        first = head;
        stats.addWithKey(head, head.id);

        // find post node
        Set<Statement> lstNodes = new HashSet<>(head.getNeighbours(StatEdge.TYPE_REGULAR, DIRECTION_FORWARD));

        // cluster nodes
        if (poststat != null) {
            post = poststat;
            lstNodes.remove(post);
        }

        default_edge = head.getSuccessorEdges(STATEDGE_DIRECT_ALL).get(0);

        for (Statement st : lstNodes) {
            stats.addWithKey(st, st.id);
        }
    }

    // *****************************************************************************
    // public methods
    // *****************************************************************************

    public static Statement isHead(Statement head) {

        if (head.type == TYPE_BASICBLOCK && head.getLastBasicType() == LASTBASICTYPE_SWITCH) {

            List<Statement> lst = new ArrayList<>();
            if (DecHelper.isChoiceStatement(head, lst)) {
                Statement post = lst.remove(0);

                for (Statement st : lst) {
                    if (st.isMonitorEnter()) {
                        return null;
                    }
                }

                if (DecHelper.checkStatementExceptions(lst)) {
                    return new SwitchStatement(head, post);
                }
            }
        }

        return null;
    }

    @Override
    public TextBuffer toJava(int indent, BytecodeMappingTracer tracer) {
        SwitchHelper.simplify(this);

        TextBuffer buf = new TextBuffer();
        buf.append(ExprProcessor.listToJava(varDefinitions, indent, tracer));
        buf.append(first.toJava(indent, tracer));

        if (isLabeled()) {
            buf.appendIndent(indent).append("label").append(this.id.toString()).append(":").appendLineSeparator();
            tracer.incrementCurrentSourceLine();
        }

        buf.appendIndent(indent).append(headexprent.get(0).toJava(indent, tracer)).append(" {").appendLineSeparator();
        tracer.incrementCurrentSourceLine();

        VarType switch_type = headexprent.get(0).getExprType();

        for (int i = 0; i < caseStatements.size(); i++) {

            Statement stat = caseStatements.get(i);
            List<StatEdge> edges = caseEdges.get(i);
            List<Exprent> values = caseValues.get(i);

            for (int j = 0; j < edges.size(); j++) {
                if (edges.get(j) == default_edge) {
                    buf.appendIndent(indent).append("default:").appendLineSeparator();
                } else {
                    buf.appendIndent(indent).append("case ");
                    Exprent value = values.get(j);
                    if (value instanceof ConstExprent) {
                        value = value.copy();
                        ((ConstExprent) value).setConstType(switch_type);
                    }
                    if (value instanceof FieldExprent && ((FieldExprent) value).isStatic()) { // enum values
                        buf.append(((FieldExprent) value).getName());
                    } else {
                        buf.append(value.toJava(indent, tracer));
                    }

                    buf.append(":").appendLineSeparator();
                }
                tracer.incrementCurrentSourceLine();
            }

            buf.append(ExprProcessor.jmpWrapper(stat, indent + 1, false, tracer));
        }

        buf.appendIndent(indent).append("}").appendLineSeparator();
        tracer.incrementCurrentSourceLine();

        return buf;
    }

    @Override
    public void initExprents() {
        SwitchExprent swexpr = (SwitchExprent) first.getExprents().remove(first.getExprents().size() - 1);
        swexpr.setCaseValues(caseValues);

        headexprent.set(0, swexpr);
    }

    @Override
    public List<Object> getSequentialObjects() {

        List<Object> lst = new ArrayList<Object>(stats);
        lst.add(1, headexprent.get(0));

        return lst;
    }

    @Override
    public void replaceExprent(Exprent oldexpr, Exprent newexpr) {
        if (headexprent.get(0) == oldexpr) {
            headexprent.set(0, newexpr);
        }
    }

    @Override
    public void replaceStatement(Statement oldstat, Statement newstat) {

        for (int i = 0; i < caseStatements.size(); i++) {
            if (caseStatements.get(i) == oldstat) {
                caseStatements.set(i, newstat);
            }
        }

        super.replaceStatement(oldstat, newstat);
    }

    @Override
    public Statement getSimpleCopy() {
        return new SwitchStatement();
    }

    @Override
    public void initSimpleCopy() {
        first = stats.get(0);
        default_edge = first.getSuccessorEdges(STATEDGE_DIRECT_ALL).get(0);

        sortEdgesAndNodes();
    }

    // *****************************************************************************
    // private methods
    // *****************************************************************************

    public void sortEdgesAndNodes() {

        HashMap<StatEdge, Integer> mapEdgeIndex = new HashMap<>();

        List<StatEdge> lstFirstSuccs = first.getSuccessorEdges(STATEDGE_DIRECT_ALL);
        for (int i = 0; i < lstFirstSuccs.size(); i++) {
            mapEdgeIndex.put(lstFirstSuccs.get(i), i == 0 ? lstFirstSuccs.size() : i);
        }

        // case values
        BasicBlockStatement bbstat = (BasicBlockStatement) first;
        int[] values = ((SwitchInstruction) bbstat.getBlock().getLastInstruction()).getValues();

        List<Statement> nodes = new ArrayList<>(stats.size() - 1);
        List<List<Integer>> edges = new ArrayList<>(stats.size() - 1);

        // collect regular edges
        for (int i = 1; i < stats.size(); i++) {

            Statement stat = stats.get(i);

            List<Integer> lst = new ArrayList<>();
            for (StatEdge edge : stat.getPredecessorEdges(StatEdge.TYPE_REGULAR)) {
                if (edge.getSource() == first) {
                    lst.add(mapEdgeIndex.get(edge));
                }
            }
            Collections.sort(lst);

            nodes.add(stat);
            edges.add(lst);
        }

        // collect exit edges
        List<StatEdge> lstExitEdges = first.getSuccessorEdges(StatEdge.TYPE_BREAK | StatEdge.TYPE_CONTINUE);
        while (!lstExitEdges.isEmpty()) {
            StatEdge edge = lstExitEdges.get(0);

            List<Integer> lst = new ArrayList<>();
            for (int i = lstExitEdges.size() - 1; i >= 0; i--) {
                StatEdge edgeTemp = lstExitEdges.get(i);
                if (edgeTemp.getDestination() == edge.getDestination() && edgeTemp.getType() == edge.getType()) {
                    lst.add(mapEdgeIndex.get(edgeTemp));
                    lstExitEdges.remove(i);
                }
            }
            Collections.sort(lst);

            nodes.add(null);
            edges.add(lst);
        }

        // sort edges (bubblesort)
        for (int i = 0; i < edges.size() - 1; i++) {
            for (int j = edges.size() - 1; j > i; j--) {
                if (edges.get(j - 1).get(0) > edges.get(j).get(0)) {
                    edges.set(j, edges.set(j - 1, edges.get(j)));
                    nodes.set(j, nodes.set(j - 1, nodes.get(j)));
                }
            }
        }

        // sort statement cliques
        for (int index = 0; index < nodes.size(); index++) {
            Statement stat = nodes.get(index);

            if (stat != null) {
                HashSet<Statement> setPreds = new HashSet<>(stat.getNeighbours(StatEdge.TYPE_REGULAR, DIRECTION_BACKWARD));
                setPreds.remove(first);

                if (!setPreds.isEmpty()) {
                    Statement pred =
                            setPreds.iterator().next(); // assumption: at most one predecessor node besides the head. May not hold true for obfuscated code.
                    for (int j = 0; j < nodes.size(); j++) {
                        if (j != (index - 1) && nodes.get(j) == pred) {
                            nodes.add(j + 1, stat);
                            edges.add(j + 1, edges.get(index));

                            if (j > index) {
                                nodes.remove(index);
                                edges.remove(index);
                                index--;
                            } else {
                                nodes.remove(index + 1);
                                edges.remove(index + 1);
                            }
                            break;
                        }
                    }
                }
            }
        }

        // translate indices back into edges
        List<List<StatEdge>> lstEdges = new ArrayList<>(edges.size());
        List<List<Exprent>> lstValues = new ArrayList<>(edges.size());

        for (List<Integer> lst : edges) {
            List<StatEdge> lste = new ArrayList<>(lst.size());
            List<Exprent> lstv = new ArrayList<>(lst.size());

            List<StatEdge> lstSuccs = first.getSuccessorEdges(STATEDGE_DIRECT_ALL);
            for (Integer in : lst) {
                int index = in == lstSuccs.size() ? 0 : in;

                lste.add(lstSuccs.get(index));
                lstv.add(index == 0 ? null : new ConstExprent(values[index - 1], false, null));
            }
            lstEdges.add(lste);
            lstValues.add(lstv);
        }

        // replace null statements with dummy basic blocks
        for (int i = 0; i < nodes.size(); i++) {
            if (nodes.get(i) == null) {
                BasicBlockStatement bstat = new BasicBlockStatement(new BasicBlock(
                        DecompilerContext.getCounterContainer().getCounterAndIncrement(CounterContainer.STATEMENT_COUNTER)));

                StatEdge sample_edge = lstEdges.get(i).get(0);

                bstat.addSuccessor(new StatEdge(sample_edge.getType(), bstat, sample_edge.getDestination(), sample_edge.closure));

                for (StatEdge edge : lstEdges.get(i)) {

                    edge.getSource().changeEdgeType(DIRECTION_FORWARD, edge, StatEdge.TYPE_REGULAR);
                    edge.closure.getLabelEdges().remove(edge);

                    edge.getDestination().removePredecessor(edge);
                    edge.getSource().changeEdgeNode(DIRECTION_FORWARD, edge, bstat);
                    bstat.addPredecessor(edge);
                }

                nodes.set(i, bstat);
                stats.addWithKey(bstat, bstat.id);
                bstat.setParent(this);
            }
        }

        caseStatements = nodes;
        caseEdges = lstEdges;
        caseValues = lstValues;
    }

    public List<Exprent> getHeadexprentList() {
        return headexprent;
    }

    public Exprent getHeadexprent() {
        return headexprent.get(0);
    }

    public List<List<StatEdge>> getCaseEdges() {
        return caseEdges;
    }

    public List<Statement> getCaseStatements() {
        return caseStatements;
    }

    public StatEdge getDefault_edge() {
        return default_edge;
    }

    public List<List<Exprent>> getCaseValues() {
        return caseValues;
    }
}
